Skip to content

Commit 35bc088

Browse files
committed
Make default system prompt configurable on web
1 parent 81ed1cf commit 35bc088

19 files changed

+349
-123
lines changed

llama.cpp/common.cpp

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
276276
// TODO: this is temporary, in the future the sampling state will be moved fully to llama_sampling_context.
277277
params.seed = std::stoul(argv[i]);
278278
sparams.seed = std::stoul(argv[i]);
279+
FLAG_seed = sparams.seed; // [jart]
279280
return true;
280281
}
281282
if (arg == "-t" || arg == "--threads") {
@@ -490,17 +491,20 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
490491
if (arg == "--top-p") {
491492
CHECK_ARG
492493
sparams.top_p = std::stof(argv[i]);
494+
FLAG_top_p = sparams.top_p; // [jart]
493495
return true;
494496
}
495497
if (arg == "--min-p") {
496498
CHECK_ARG
497499
sparams.min_p = std::stof(argv[i]);
498500
return true;
499501
}
500-
if (arg == "--temp") {
502+
if (arg == "--temp" || //
503+
arg == "--temperature") { // [jart]
501504
CHECK_ARG
502505
sparams.temp = std::stof(argv[i]);
503506
sparams.temp = std::max(sparams.temp, 0.0f);
507+
FLAG_temperature = sparams.temp; // [jart]
504508
return true;
505509
}
506510
if (arg == "--tfs") {
@@ -527,11 +531,13 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
527531
if (arg == "--frequency-penalty") {
528532
CHECK_ARG
529533
sparams.penalty_freq = std::stof(argv[i]);
534+
FLAG_frequency_penalty = sparams.penalty_freq; // [jart]
530535
return true;
531536
}
532537
if (arg == "--presence-penalty") {
533538
CHECK_ARG
534539
sparams.penalty_present = std::stof(argv[i]);
540+
FLAG_presence_penalty = sparams.penalty_present; // [jart]
535541
return true;
536542
}
537543
if (arg == "--dynatemp-range") {
@@ -903,8 +909,15 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
903909
params.verbose_prompt = true;
904910
return true;
905911
}
906-
if (arg == "--no-display-prompt" || arg == "--silent-prompt") {
912+
if (arg == "--no-display-prompt" || //
913+
arg == "--silent-prompt") { // [jart]
907914
params.display_prompt = false;
915+
FLAG_no_display_prompt = true; // [jart]
916+
return true;
917+
}
918+
if (arg == "--display-prompt") { // [jart]
919+
params.display_prompt = true;
920+
FLAG_no_display_prompt = false;
908921
return true;
909922
}
910923
if (arg == "-r" || arg == "--reverse-prompt") {

llama.cpp/server/server.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2593,6 +2593,15 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
25932593
{
25942594
FLAG_nologo = true;
25952595
}
2596+
else if (arg == "--no-display-prompt" || //
2597+
arg == "--silent-prompt")
2598+
{
2599+
FLAG_no_display_prompt = true;
2600+
}
2601+
else if (arg == "--display-prompt")
2602+
{
2603+
FLAG_no_display_prompt = false;
2604+
}
25962605
else if (arg == "--trap")
25972606
{
25982607
FLAG_trap = true;

llamafile/flags.cpp

Lines changed: 56 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ bool FLAG_iq = false;
4343
bool FLAG_log_disable = false;
4444
bool FLAG_mlock = false;
4545
bool FLAG_mmap = true;
46+
bool FLAG_no_display_prompt = false;
4647
bool FLAG_nocompile = false;
4748
bool FLAG_nologo = false;
4849
bool FLAG_precise = false;
@@ -59,7 +60,10 @@ const char *FLAG_prompt = nullptr;
5960
const char *FLAG_url_prefix = "";
6061
const char *FLAG_www_root = "/zip/www";
6162
double FLAG_token_rate = 1;
62-
float FLAG_temp = 0.8;
63+
float FLAG_frequency_penalty = 0;
64+
float FLAG_presence_penalty = 0;
65+
float FLAG_temperature = .8;
66+
float FLAG_top_p = .95;
6367
int FLAG_batch = 2048;
6468
int FLAG_ctx_size = 8192;
6569
int FLAG_flash_attn = false;
@@ -69,7 +73,6 @@ int FLAG_http_obuf_size = 1024 * 1024;
6973
int FLAG_keepalive = 5;
7074
int FLAG_main_gpu = 0;
7175
int FLAG_n_gpu_layers = -1;
72-
int FLAG_seed = LLAMA_DEFAULT_SEED;
7376
int FLAG_slots = 1;
7477
int FLAG_split_mode = LLAMA_SPLIT_MODE_LAYER;
7578
int FLAG_threads = MIN(cpu_get_num_math(), 20);
@@ -80,6 +83,7 @@ int FLAG_ubatch = 512;
8083
int FLAG_verbose = 0;
8184
int FLAG_warmup = true;
8285
int FLAG_workers;
86+
unsigned FLAG_seed = LLAMA_DEFAULT_SEED;
8387

8488
std::vector<std::string> FLAG_headers;
8589

@@ -153,6 +157,17 @@ void llamafile_get_flags(int argc, char **argv) {
153157
continue;
154158
}
155159

160+
if (!strcmp(flag, "--no-display-prompt") || //
161+
!strcmp(flag, "--silent-prompt")) {
162+
FLAG_no_display_prompt = true;
163+
continue;
164+
}
165+
166+
if (!strcmp(flag, "--display-prompt")) {
167+
FLAG_no_display_prompt = false;
168+
continue;
169+
}
170+
156171
//////////////////////////////////////////////////////////////////////
157172
// server flags
158173

@@ -278,6 +293,45 @@ void llamafile_get_flags(int argc, char **argv) {
278293
continue;
279294
}
280295

296+
//////////////////////////////////////////////////////////////////////
297+
// sampling flags
298+
299+
if (!strcmp(flag, "--seed")) {
300+
if (i == argc)
301+
missing("--seed");
302+
FLAG_seed = strtol(argv[i++], 0, 0);
303+
continue;
304+
}
305+
306+
if (!strcmp(flag, "--temp") || //
307+
!strcmp(flag, "--temperature")) {
308+
if (i == argc)
309+
missing("--temp");
310+
FLAG_temperature = atof(argv[i++]);
311+
continue;
312+
}
313+
314+
if (!strcmp(flag, "--top-p")) {
315+
if (i == argc)
316+
missing("--top-p");
317+
FLAG_top_p = atof(argv[i++]);
318+
continue;
319+
}
320+
321+
if (!strcmp(flag, "--frequency-penalty")) {
322+
if (i == argc)
323+
missing("--frequency-penalty");
324+
FLAG_frequency_penalty = atof(argv[i++]);
325+
continue;
326+
}
327+
328+
if (!strcmp(flag, "--presence-penalty")) {
329+
if (i == argc)
330+
missing("--presence-penalty");
331+
FLAG_presence_penalty = atof(argv[i++]);
332+
continue;
333+
}
334+
281335
//////////////////////////////////////////////////////////////////////
282336
// model flags
283337

@@ -319,20 +373,6 @@ void llamafile_get_flags(int argc, char **argv) {
319373
continue;
320374
}
321375

322-
if (!strcmp(flag, "--seed")) {
323-
if (i == argc)
324-
missing("--seed");
325-
FLAG_seed = atoi(argv[i++]);
326-
continue;
327-
}
328-
329-
if (!strcmp(flag, "--temp")) {
330-
if (i == argc)
331-
missing("--temp");
332-
FLAG_temp = atof(argv[i++]);
333-
continue;
334-
}
335-
336376
if (!strcmp(flag, "-t") || !strcmp(flag, "--threads")) {
337377
if (i == argc)
338378
missing("--threads");

llamafile/llamafile.h

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ extern bool FLAG_iq;
1313
extern bool FLAG_log_disable;
1414
extern bool FLAG_mlock;
1515
extern bool FLAG_mmap;
16+
extern bool FLAG_no_display_prompt;
1617
extern bool FLAG_nocompile;
1718
extern bool FLAG_nologo;
1819
extern bool FLAG_precise;
@@ -30,7 +31,10 @@ extern const char *FLAG_prompt;
3031
extern const char *FLAG_url_prefix;
3132
extern const char *FLAG_www_root;
3233
extern double FLAG_token_rate;
33-
extern float FLAG_temp;
34+
extern float FLAG_frequency_penalty;
35+
extern float FLAG_presence_penalty;
36+
extern float FLAG_temperature;
37+
extern float FLAG_top_p;
3438
extern int FLAG_batch;
3539
extern int FLAG_ctx_size;
3640
extern int FLAG_flash_attn;
@@ -41,7 +45,6 @@ extern int FLAG_http_obuf_size;
4145
extern int FLAG_keepalive;
4246
extern int FLAG_main_gpu;
4347
extern int FLAG_n_gpu_layers;
44-
extern int FLAG_seed;
4548
extern int FLAG_slots;
4649
extern int FLAG_split_mode;
4750
extern int FLAG_threads;
@@ -52,6 +55,7 @@ extern int FLAG_ubatch;
5255
extern int FLAG_verbose;
5356
extern int FLAG_warmup;
5457
extern int FLAG_workers;
58+
extern unsigned FLAG_seed;
5559

5660
struct llamafile;
5761
struct llamafile *llamafile_open_gguf(const char *, const char *);

llamafile/server/client.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -654,6 +654,8 @@ Client::dispatcher()
654654
return v1_chat_completions();
655655
if (p1 == "slotz")
656656
return slotz();
657+
if (p1 == "flagz")
658+
return flagz();
657659

658660
// serve static endpoints
659661
int infd;

llamafile/server/client.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ struct Client
6666
char* params_memory_ = nullptr;
6767
std::string_view payload_;
6868
std::string resolved_;
69+
std::string dump_;
6970
Cleanup* cleanups_;
7071
Buffer ibuf_;
7172
Buffer obuf_;
@@ -112,6 +113,7 @@ struct Client
112113
bool get_v1_chat_completions_params(V1ChatCompletionParams*) __wur;
113114

114115
bool slotz() __wur;
116+
bool flagz() __wur;
115117
};
116118

117119
} // namespace server

llamafile/server/embedding.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@
2727
#include <sys/resource.h>
2828
#include <vector>
2929

30+
using jt::Json;
31+
3032
namespace lf {
3133
namespace server {
3234

llamafile/server/flagz.cpp

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
// -*- mode:c++;indent-tabs-mode:nil;c-basic-offset:4;coding:utf-8 -*-
2+
// vi: set et ft=cpp ts=4 sts=4 sw=4 fenc=utf-8 :vi
3+
//
4+
// Copyright 2024 Mozilla Foundation
5+
//
6+
// Licensed under the Apache License, Version 2.0 (the "License");
7+
// you may not use this file except in compliance with the License.
8+
// You may obtain a copy of the License at
9+
//
10+
// http://www.apache.org/licenses/LICENSE-2.0
11+
//
12+
// Unless required by applicable law or agreed to in writing, software
13+
// distributed under the License is distributed on an "AS IS" BASIS,
14+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
// See the License for the specific language governing permissions and
16+
// limitations under the License.
17+
18+
#include "client.h"
19+
#include "llama.cpp/llama.h"
20+
#include "llamafile/llamafile.h"
21+
#include "llamafile/server/json.h"
22+
23+
namespace lf {
24+
namespace server {
25+
26+
bool
27+
Client::flagz()
28+
{
29+
jt::Json json;
30+
json["prompt"] = FLAG_prompt;
31+
json["no_display_prompt"] = FLAG_no_display_prompt;
32+
json["nologo"] = FLAG_nologo;
33+
json["temperature"] = FLAG_temperature;
34+
json["presence_penalty"] = FLAG_presence_penalty;
35+
json["frequency_penalty"] = FLAG_frequency_penalty;
36+
if (FLAG_seed == LLAMA_DEFAULT_SEED) {
37+
json["seed"] = nullptr;
38+
} else {
39+
json["seed"] = FLAG_seed;
40+
}
41+
dump_ = json.toStringPretty();
42+
dump_ += '\n';
43+
char* p = append_http_response_message(obuf_.p, 200);
44+
p = stpcpy(p, "Content-Type: application/json\r\n");
45+
return send_response(obuf_.p, p, dump_);
46+
}
47+
48+
} // namespace server
49+
} // namespace lf

0 commit comments

Comments
 (0)